import os, sys
import numpy as np
import torch
from einops import rearrange
from PIL import Image
import torchvision.transforms as transforms
from config import *
import wandb
import datetime
import argparse
import os.path as op


import config as cfg
from dataset import create_EEG_dataset
from ldm_for_eeg import eLDM

def to_image(img):
    """Convert an image array to an 8-bit PIL image.

    Args:
        img: numpy array in either channel-last ``(h, w, 3)`` or channel-first
            ``(c, h, w)`` layout. Any array whose last axis is not 3 is treated
            as channel-first and moved to channel-last. Values are expected in
            [0, 1]; they are scaled by 255.

    Returns:
        ``PIL.Image`` built from the scaled array cast to ``uint8``.
    """
    if img.shape[-1] != 3:
        img = rearrange(img, 'c h w -> h w c')
    # Clip before the uint8 cast. Out-of-range floats (e.g. small overshoots
    # past 0 or 1) would otherwise wrap around or be undefined in astype.
    img = np.clip(255. * img, 0, 255)
    return Image.fromarray(img.astype(np.uint8))

def channel_last(img):
    """Return ``img`` in channel-last ``(h, w, c)`` layout.

    If the last axis already has size 3, the input is taken to be
    channel-last and is returned unchanged (the very same object).
    Otherwise it is read as ``(c, h, w)`` and rearranged.
    """
    already_hwc = img.shape[-1] == 3
    return img if already_hwc else rearrange(img, 'c h w -> h w c')

def normalize(img):
    """Return ``img`` as a channel-first torch tensor rescaled by ``2 * x - 1``.

    A last axis of size 3 is read as ``(h, w, c)`` and moved to ``(c, h, w)``.
    Any other layout is passed through as-is. The affine map sends values
    in [0, 1] to [-1, 1].
    """
    chw = rearrange(img, 'h w c -> c h w') if img.shape[-1] == 3 else img
    scaled = torch.tensor(chw)
    return scaled * 2.0 - 1.0  # to -1 ~ 1

def wandb_init(config):
    """Start a Weights & Biases run in the "eval" group of "dreamdiffusion".

    ``reinit=True`` is passed so the call may be made more than once per
    process. ``config`` is forwarded unchanged as the run config.
    """
    run_kwargs = dict(
        project="dreamdiffusion",
        group='eval',
        anonymous="allow",
        config=config,
        reinit=True,
    )
    wandb.init(**run_kwargs)

class random_crop:
    """Apply a square ``transforms.RandomCrop`` of side ``size`` with probability ``p``."""

    def __init__(self, size, p):
        self.p = p
        self.size = size

    def __call__(self, img):
        # One torch RNG draw per call decides whether this image is cropped.
        should_crop = torch.rand(1) < self.p
        if not should_crop:
            return img
        cropper = transforms.RandomCrop(size=(self.size, self.size))
        return cropper(img)

def get_args_parser():
    """Build the command-line parser for evaluating a fine-tuned checkpoint.

    The parser is created with ``add_help=False``, so it defines no
    ``-h/--help`` of its own. That is the usual setup for an argparse parent
    parser.
    """
    parser = argparse.ArgumentParser('Double Conditioning LDM Finetuning', add_help=False)
    # project parameters: (flag, type, default)
    project_options = (
        ('--root_path', str, cfg.project_path),
        ('--dataset', str, 'WM'),
        ('--subject', int, 0),
        ('--model_date', str, '19-05-2024-00-47-48'),
        ('--model_path', str, ''),
    )
    for flag, arg_type, default in project_options:
        parser.add_argument(flag, type=arg_type, default=default)
    return parser


if __name__ == '__main__':
    import inspect

    # get_args_parser() is built with add_help=False, so wrap it as a parent
    # parser to give this script a working -h/--help flag.
    parser = argparse.ArgumentParser('Double Conditioning LDM Finetuning',
                                     parents=[get_args_parser()])
    args = parser.parse_args()
    root = args.root_path

    # Default checkpoint location:
    # <root>/exps/results/diffusion_finetune/<dataset>/sub_<subject>/<model_date>/checkpoint_best.pth
    if args.model_path == '':
        args.model_path = os.path.join(root, 'exps', 'results',
                                       'diffusion_finetune', args.dataset, f'sub_{args.subject}',
                                       args.model_date, 'checkpoint_best.pth')

    # The checkpoint stores a full Python config object (sd['config']). The
    # weights-only loader, which torch >= 2.6 uses by default, rejects it, so
    # request full unpickling explicitly wherever torch.load accepts the flag.
    # NOTE: full unpickling can execute arbitrary code; only load trusted checkpoints.
    load_kwargs = {'map_location': 'cpu'}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    sd = torch.load(args.model_path, **load_kwargs)
    config = sd['config']

    # Re-anchor relative '*path*' config entries under `root`. Anything up to
    # the last '../' in the stored value is discarded.
    for k, v in config.__dict__.items():
        if 'path' in k and v is not None and not op.isabs(v):
            config.__dict__[k] = op.join(root, v.split('../')[-1])
    config.root_path = root
    print(config.__dict__)

    output_path = os.path.join(os.path.dirname(args.model_path), 'eval')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Train and test splits share the same deterministic preprocessing here:
    # normalize -> resize to 512x512 -> channel-last.
    img_transform = transforms.Compose([
        normalize,
        transforms.Resize((512, 512)),
        channel_last,
    ])

    config.imagenet_path = cfg.image_dir
    config.pretrain_gm_path = op.join(cfg.project_path, 'pretrains')
    _, _, dataset_test = create_EEG_dataset(dataset=config.dataset,
                                            image_transform=[img_transform, img_transform],
                                            subject=config.subject,
                                            imagenet_path=config.imagenet_path)
    print(len(dataset_test))

    # metafile=None: no separate pretrained-encoder metafile is loaded here;
    # the model weights are restored from the checkpoint below (strict=True).
    generative_model = eLDM(metafile=None, device=device,
                            pretrain_root=config.pretrain_gm_path,
                            logger=config.logger,
                            ddim_steps=config.ddim_steps,
                            global_pool=config.global_pool,
                            use_time_cond=config.use_time_cond,
                            clip_tune=config.clip_tune,
                            cls_tune=config.cls_tune)

    generative_model.model.load_state_dict(sd['model_state_dict'], strict=True)
    print('load ldm successfully')
    state = sd['state']
    os.makedirs(output_path, exist_ok=True)

    # NOTE(review): limit=None presumably means "no cap on test items";
    # confirm against eLDM.generate.
    grid, samples = generative_model.generate(dataset_test, config.num_samples,
                                              config.ddim_steps, config.HW, limit=None,
                                              state=state, output_path=output_path)
    grid_imgs = Image.fromarray(grid.astype(np.uint8))
    grid_imgs.save(os.path.join(output_path, 'samples_test.png'))
